import sys
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torchvision.models as models
import torch.utils.data as data
from sklearn.preprocessing import OneHotEncoder
import random
from copy import deepcopy
from model.models import *
from model.resnet import *
from model.vgg import *
from model.googlenet import *
from torch.utils.data import Dataset
from instance import *

class CIFAR10_Augmentention(Dataset):   # augmentation
    """Training dataset that augments each image when it is fetched.

    ``images``, ``mincom``, ``com``, ``true_labels`` and ``id`` are parallel
    sequences indexed by the same position. Only the image goes through
    ``self.transform``: PIL conversion, a random 32x32 crop with 4-pixel
    padding, a random horizontal flip, tensor conversion, and per-channel
    normalisation. The other fields are returned exactly as stored.
    """

    def __init__(self, images, mincom, com, true_labels, id):
        # NOTE(review): ToPILImage needs each image to be an HxWxC array or a
        # CxHxW tensor; presumably uint8 HxWxC from extract_data — confirm.
        self.images = images
        self.mincom = mincom
        self.com = com
        self.true_labels = true_labels
        self.id = id

        augment_steps = [
            transforms.ToPILImage(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(augment_steps)

    def __len__(self):
        # The label sequence determines the dataset size.
        return len(self.true_labels)

    def __getitem__(self, index):
        """Return ``(image, mincom, com, true_label, id)`` at ``index``."""
        augmented = self.transform(self.images[index])
        passthrough = (self.mincom, self.com, self.true_labels, self.id)
        return (augmented, *(column[index] for column in passthrough))
def class_prior(complementary_labels, num_classes=None):
    """Return the empirical frequency of each non-negative integer label.

    Args:
        complementary_labels: 1-D list, NumPy array, or torch tensor of
            non-negative integer labels. Tensors are detached and copied to
            the CPU first, because ``np.bincount`` cannot read device memory.
        num_classes: Optional minimum length of the result. Without it, the
            vector only extends to ``max(label) + 1``. A class that never
            occurs at the top end would then silently drop out.

    Returns:
        Float NumPy array whose entry ``k`` is the fraction of labels equal to
        ``k``. Empty input yields zeros of length ``num_classes`` (an empty
        array when ``num_classes`` is None) rather than dividing by zero.
    """
    labels = complementary_labels
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    labels = np.asarray(labels)
    minlength = 0 if num_classes is None else num_classes
    if labels.size == 0:
        return np.zeros(minlength)
    counts = np.bincount(labels, minlength=minlength)
    return counts / len(labels)
# Load CIFAR-10 and build the train/test data loaders.
def load_cifar10(batch_size, pre_model='resnet18', label_num=3):
    """Load CIFAR-10 and build loaders together with generated complementary labels.

    Args:
        batch_size: Batch size for both data loaders; also forwarded to
            ``generate_mincompl_labels``.
        pre_model: Network used by ``generate_mincompl_labels``; one of
            ``'resnet18'``, ``'resnet34'``, ``'vgg16'`` or ``'googlenet'``.
        label_num: Forwarded unchanged to ``generate_mincompl_labels``.

    Returns:
        Tuple ``(train_loader, test_loader, one_hot_mincom, one_hot_com,
        ccp_com, ccp_mincom)``. The last two are the label frequencies of
        ``com`` and ``mincom`` as computed by ``class_prior``.

    Raises:
        ValueError: if ``pre_model`` is not one of the supported names.
    """
    data_type = 'cifar10'
    num_classes = 10  # CIFAR-10 has ten classes; fixes the one-hot width below.
    train_X, train_Y, test_X, test_Y = next(extract_data(data_type))  # load data

    if pre_model == 'resnet18':
        partialize_net = resnet18().to(device)
    elif pre_model == 'resnet34':
        partialize_net = resnet34().to(device)
    elif pre_model == 'vgg16':
        partialize_net = VGGNet().to(device)
    elif pre_model == 'googlenet':
        partialize_net = googlenet().to(device)
    else:
        # Fail fast; otherwise partialize_net would be unbound below.
        raise ValueError(
            f"unsupported pre_model {pre_model!r}; expected one of "
            "'resnet18', 'resnet34', 'vgg16', 'googlenet'"
        )

    # Generate complementary labels from the original training labels.
    mincom = generate_mincompl_labels(
        train_X=train_X, train_Y=train_Y, model=partialize_net,
        batch_size=batch_size, label_num=label_num,
        data_type=data_type, pre_model=pre_model,
    )
    com = generate_compl_labels(labels=train_Y)
    # Without num_classes, F.one_hot infers the width as max(label) + 1. A
    # class missing from the top end would then shrink the matrix.
    one_hot_mincom = F.one_hot(mincom, num_classes=num_classes)
    one_hot_com = F.one_hot(com, num_classes=num_classes)

    sample_ids = torch.arange(len(train_Y))

    test_transform = transforms.Compose(
        [transforms.ToTensor(),  # change to tensor
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])  # same normalisation as training
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # topk(..., 1) followed by squeeze(1) implies train_Y is 2-D (N, C).
    # Reduce each row to the index of its largest entry (class id).
    true_labels = torch.topk(torch.as_tensor(train_Y), 1)[1].squeeze(1)

    ccp_com = class_prior(com)
    ccp_mincom = class_prior(mincom)

    train_data = CIFAR10_Augmentention(train_X, mincom, com, true_labels, sample_ids)
    train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    return train_loader, test_loader, one_hot_mincom, one_hot_com, ccp_com, ccp_mincom


    

